{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evade Signal Classification with FGSM\n",
    "\n",
    "This notebook shows the performance of a digital FGSM attack on an Automatic Modulation Classification neural network model.\n",
    "\n",
    "### Assumptions\n",
    "- The dataset wrangling has already been completed (and is provided here)\n",
    "- The implementation of model training on this dataset has already been completed\n",
    "- The plotting code has already been completed\n",
    "\n",
    "### Components Recreated in Tutorial\n",
    "- FGSM attack constrained by a power ratio\n",
    "\n",
    "### See Also\n",
    "The code in this tutorial is a stripped down version of the code in ``rfml.attack.fgsm`` that simplifies discussion.  Further detail can be provided by directly browsing the source files."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install the library code and dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library code\n",
    "!pip install git+https://github.com/brysef/rfml.git@1.0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting Includes\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# External Includes\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch.nn.functional import cross_entropy\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Internal Includes\n",
    "from rfml.data import Dataset, Encoder\n",
    "from rfml.data import build_dataset\n",
    "\n",
    "from rfml.nbutils import plot_acc_vs_spr, plot_confusion\n",
    "\n",
    "from rfml.nn.eval import compute_accuracy\n",
    "from rfml.nn.eval.confusion import _confusion_matrix\n",
    "from rfml.nn.F import energy\n",
    "from rfml.nn.model import build_model\n",
    "from rfml.nn.train import build_trainer, PrintingTrainingListener"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu = True       # Set to True to use a GPU for training\n",
    "fig_dir = None   # Set to a file path if you'd like to save the plots generated\n",
    "data_path = None # Set to a file path if you've downloaded RML2016.10A locally"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and Evaluate a Model on a Static Dataset\n",
    "\n",
    "See ``module_2`` for a more extensive description of this procedure.\n",
    "Here we just use the library for a quick implementation.\n",
    "In practice, the model training would be separated from the attack, but, is juxtaposed with the attack here for simplicity as it is self-contained for demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test, val, le = build_dataset(\"RML2016.10a\", path=data_path)\n",
    "# as_numpy returns x,y and x is shape BxCxIQxN\n",
    "input_samples = val.as_numpy(le=le)[0].shape[3]\n",
    "model = build_model(model_name=\"CNN\", input_samples=input_samples, n_classes=len(le))\n",
    "trainer = build_trainer(strategy=\"standard\", max_epochs=3, gpu=gpu)\n",
    "trainer.register_listener(PrintingTrainingListener())\n",
    "trainer(model=model, training=train, validation=val, le=le)\n",
    "acc = compute_accuracy(model=model, data=test, le=le)\n",
    "\n",
    "print(\"Overall Testing Accuracy: {:.4f}\".format(acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Strip Down the Testing Data\n",
    "Strip the data down to only the highest SNR (18 dB in RML2016.10a).  This ensures that the classification accuracy would have been close to the highest and gives the best evaluation of the attack because it separates the decreased accuracy from being causes by low SNR vs being caused by the attack."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = test.df[\"SNR\"] >= 18\n",
    "dl = DataLoader(test.as_torch(le=le, mask=mask), shuffle=True, batch_size=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evading Signal Classification with Direct Access to the Classifier\n",
    "\n",
    "Here, we're simply going to sweep the **intensity** of the attack, denoted in this tutorial as $E_s/E_p$ or signal-to-perturbation ratio represented in dB.\n",
    "\n",
    "Recall that the equation for FGSM [Goodfellow et. al] is\n",
    "\n",
    "\\begin{equation}\n",
    "    x^* = x + \\text{sign}(\\nabla_X \\mathcal{L}(f(\\theta, X), y_s))\n",
    "\\end{equation}\n",
    "\n",
    "For the purposes of a clean implementation in software, we can split this equation into multiple steps.\n",
    "First, we must compute the *signed gradient*.\n",
    "\n",
    "\\begin{equation}\n",
    "    \\text{\"signed gradient\"} = \\text{sign}(\\nabla_X \\mathcal{L}(f(\\theta, X), y_s))\n",
    "\\end{equation}\n",
    "\n",
    "This is done by utilizing the model and PyTorch to backpropogate the gradient to the input.\n",
    "Once the gradient is known, we can apply a simple sign operation.\n",
    "\n",
    "Then, the signed gradient must be scaled to achieve an allowed intensity.\n",
    "Traditionally, CV literature uses a norm to constrain the perturbation (e.g. $\\left\\lVert {x - x^*} \\right\\rVert_p < \\epsilon$).\n",
    "In the context of wireless communications, it makes more sense to constrain the power ratio of the perturbation to the underlying signal as the absolute values of either the perturbation or signal generally do not matter.\n",
    "Recall that we've assumed $E_s$ is $1$ and therefore can scale the perturbation using the following equation [Flowers et. al, Sadeghi/Larson].\n",
    "\n",
    "\\begin{equation}\n",
    "    p = \\sqrt{\\frac{10^{\\frac{-E_s/E_p}{10}}}{2 \\times \\text{sps}}} \\times \\text{\"signed gradient\"}\n",
    "\\end{equation}\n",
    "\n",
    "Because the RML2016.10A dataset is not properly normalized for all examples, especially AM-SSB, we have to normalize the input data before running the attack to have a known power -- this is the only way to ensure that the attack intensity is correct.\n",
    "\n",
    "To complete the FGSM algorithm we simply need to add the perturbation to the original example.\n",
    "\n",
    "\\begin{equation}\n",
    "    \\text{\"adversarial example\"} = x + p\n",
    "\\end{equation}\n",
    "\n",
    "#### Citations\n",
    "\n",
    "##### Goodfellow et al\n",
    "\n",
    "Goodfellow, I., Shlens, J., and Szegedy, C. (2015).  Explaining and harnessing adversarial examples. In Int. Conf. on Learning Representations.\n",
    "\n",
    "##### Flowers et al\n",
    "\n",
    "Flowers, B., Buehrer,  R.  M.,  and Headley,  W. C. (2019). Evaluating adversarial evasion attacks in the context of wireless communications. IEEE Transactions on Information Forensics and Security, pages 1–1.\n",
    "\n",
    "##### Sadeghi/Larson\n",
    "\n",
    "Sadeghi, M. and Larsson, E. G. (2018). Adversarial attacks on deep-learning based radio signal classification.IEEEWireless Commun. Letters, pages 1–1."
   ]
  },
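  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick numerical check of the scaling equation above, the following sketch builds a unit-energy tone, adds a scaled perturbation whose entries are all $\\pm 1$ before scaling, and measures the resulting $E_s/E_p$.  It assumes that energy is defined as sps times the time average of $I^2 + Q^2$, which is believed (but should be verified against the source) to match ``rfml.nn.F.energy``; the helpers ``energy`` and ``scale`` are illustrative stand-ins rather than library functions.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "\n",
    "def energy(i, q, sps=1):\n",
    "    # Assumed definition: sps * mean(I^2 + Q^2) over time\n",
    "    return sps * sum(a * a + b * b for a, b in zip(i, q)) / len(i)\n",
    "\n",
    "\n",
    "def scale(spr_db, sps=1):\n",
    "    # Multiplier from the scaling equation, valid when E_s == 1\n",
    "    return math.sqrt(10 ** (-spr_db / 10.0) / (2 * sps))\n",
    "\n",
    "\n",
    "random.seed(0)\n",
    "n = 128\n",
    "# Unit-energy tone: I^2 + Q^2 == 1 at every sample\n",
    "i = [math.cos(0.1 * k) for k in range(n)]\n",
    "q = [math.sin(0.1 * k) for k in range(n)]\n",
    "\n",
    "spr_db = 10.0\n",
    "m = scale(spr_db)\n",
    "# Stand-in for the signed gradient: every entry is +1 or -1\n",
    "p_i = [m * random.choice((-1.0, 1.0)) for _ in range(n)]\n",
    "p_q = [m * random.choice((-1.0, 1.0)) for _ in range(n)]\n",
    "\n",
    "measured_spr_db = 10.0 * math.log10(energy(i, q) / energy(p_i, p_q))\n",
    "print(round(measured_spr_db, 6))\n",
    "```\n",
    "\n",
    "The measured ratio should come out at the requested 10 dB.  If some entries of the signed gradient were exactly zero, the perturbation would carry less energy and the measured ratio would be higher than requested."
   ]
  },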
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Helper methods for the attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fgsm(x, y, input_size, net, spr, sps):\n",
    "    p = compute_signed_gradient(x=x, y=y, input_size=input_size, net=net)\n",
    "    p = scale_perturbation(sg=p, spr=spr, sps=sps)\n",
    "    \n",
    "    return x + p\n",
    "\n",
    "\n",
    "def compute_signed_gradient(x, y, input_size, net):\n",
    "    # Ensure that the gradient is tracked at the input, add some noise to avoid any\n",
    "    # actual zeros in the signal (dithering), and then ensure its the proper shape\n",
    "    x.requires_grad = True\n",
    "    _x = _dither(x)\n",
    "\n",
    "    # Put the inputs/outputs onto the most probable device that the model is currently\n",
    "    # on -- this could fail if the model gets split amongst multiple devies, but, that\n",
    "    # doesn't happen in this code.\n",
    "    _x = _x.to(net.device)\n",
    "    y = y.to(net.device)\n",
    "\n",
    "    # Perform forward/backward pass to get the gradient at the input\n",
    "    _y = net(_x)\n",
    "    loss = cross_entropy(_y, y)\n",
    "    loss.backward()\n",
    "\n",
    "    # Take the sign of the gradient that can be scaled later\n",
    "    ret = torch.sign(x.grad.data)\n",
    "\n",
    "    return ret\n",
    "\n",
    "\n",
    "def scale_perturbation(sg, spr, sps):\n",
    "    if spr == np.inf:\n",
    "        return sg * 0\n",
    "    multiplier = pow(10, -spr / 10.0)\n",
    "    multiplier = multiplier / (2 * sps)\n",
    "    multiplier = pow(multiplier, 0.5)\n",
    "\n",
    "    return sg * multiplier\n",
    "\n",
    "\n",
    "def _dither(x):\n",
    "    snr = 100\n",
    "    voltage = pow(pow(10.0, -snr / 10.0), 0.5)\n",
    "\n",
    "    noise = x.data.new(x.size()).normal_(0.0, voltage)\n",
    "    return x + noise\n",
    "\n",
    "\n",
    "def _normalize(x):\n",
    "    power = energy(x)\n",
    "    # Make the dimensions match because broadcasting is too magical to\n",
    "    # understand in its entirety... essentially want to ensure that we\n",
    "    # divide each channel of each example by the sqrt of the power of\n",
    "    # that channel/example pair\n",
    "    power = power.view([power.size()[0], power.size()[1], 1, 1])\n",
    "\n",
    "    return x / torch.sqrt(power)\n",
    "\n",
    "\n",
    "def _sanity_check(desired_spr, adv_x, x):\n",
    "    signal_power = energy(x)\n",
    "    perturbation_power = energy(adv_x - x)\n",
    "    _spr = 10*torch.log10(signal_power / perturbation_power)\n",
    "    _spr = _spr.detach().numpy().mean()\n",
    "    if np.abs(_spr - desired_spr) > 0.5:\n",
    "        raise RuntimeError(\"Calculated SPR and desired SPR does not match: \"\n",
    "                           \"Desired SPR={:0.2f}dB, Calculated SPR={:0.2f}dB, \"\n",
    "                           \"Signal Power={:0.2f}dB, Perturbation Power={:0.2f}dB\".format(\n",
    "                               desired_spr,\n",
    "                               _spr,\n",
    "                               10.0*np.log10(signal_power.detach().numpy().mean()),\n",
    "                               10.0*np.log10(perturbation_power.detach().numpy().mean()))\n",
    "                          )"
   ]
  },
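  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the signed-gradient step in isolation, here is a minimal sketch on a hypothetical two-feature logistic-regression classifier, where the input gradient of the cross-entropy loss has the closed form $(\\sigma(w \\cdot x + b) - y)\\,w$ and no autograd is needed.  The weights, input, and step size are made up for illustration; the notebook's ``compute_signed_gradient`` obtains the corresponding quantity for the neural network through PyTorch backpropagation instead.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + math.exp(-z))\n",
    "\n",
    "\n",
    "def loss(w, b, x, y):\n",
    "    # Binary cross-entropy for a single example with label y in {0, 1}\n",
    "    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)\n",
    "    return -(y * math.log(p) + (1 - y) * math.log(1 - p))\n",
    "\n",
    "\n",
    "def signed_gradient(w, b, x, y):\n",
    "    # d(loss)/dx = (sigmoid(w.x + b) - y) * w, then keep only the sign\n",
    "    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)\n",
    "    grad = [(p - y) * wi for wi in w]\n",
    "    return [math.copysign(1.0, g) if g != 0 else 0.0 for g in grad]\n",
    "\n",
    "\n",
    "# Hypothetical model and example (true label is 1)\n",
    "w, b = [2.0, -1.0], 0.5\n",
    "x, y = [1.0, 0.5], 1\n",
    "\n",
    "eps = 0.25  # illustrative step size\n",
    "sg = signed_gradient(w, b, x, y)\n",
    "adv_x = [xi + eps * si for xi, si in zip(x, sg)]\n",
    "\n",
    "clean_loss = loss(w, b, x, y)\n",
    "adv_loss = loss(w, b, adv_x, y)\n",
    "print(sg, clean_loss < adv_loss)\n",
    "```\n",
    "\n",
    "Stepping along the sign of the gradient raises the loss on the true label, which is the mechanism the untargeted attack relies on to push examples across the decision boundary."
   ]
  },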
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Top Level control loop for the attack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Ensure that the model is in \"evaluation\" mode\n",
    "# -- Therefore batch normalization is not computed and dropout is not performed\n",
    "# -- Note: This is the cause of a lot of bugs\n",
    "model.eval()\n",
    "\n",
    "acc_vs_spr = list()\n",
    "sprs = list()\n",
    "\n",
    "for spr in np.linspace(50.0, 0.0, num=26):\n",
    "    right = 0\n",
    "    total = 0\n",
    "    \n",
    "    for x, y in dl:\n",
    "        # Before crafting an attack on X, ensure that it is normalized to have a\n",
    "        # specified power -- RML2016.10a fluctuates a bit, especially on AM-SSB,\n",
    "        # which will cause the attack intensity to be incorrect.\n",
    "        x = _normalize(x)\n",
    "        adv_x = fgsm(x, y, spr=spr, input_size=input_samples, sps=1, net=model)\n",
    "        \n",
    "        # Ensure that we've accurately represented the attack power\n",
    "        _sanity_check(desired_spr=spr, adv_x=adv_x, x=x)\n",
    "        \n",
    "        predictions = model.predict(adv_x)\n",
    "        right += (predictions == y).sum().item()\n",
    "        total += len(y)\n",
    "\n",
    "    acc = float(right) / total\n",
    "    acc_vs_spr.append(acc)\n",
    "    sprs.append(spr)\n",
    "\n",
    "fig = plot_acc_vs_spr(acc_vs_spr=acc_vs_spr,\n",
    "                      spr=sprs,\n",
    "                      title=\"Performance of a Digital FGSM Attack\"\n",
    "                     )\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/direct_access_fgsm.pdf\".format(fig_dir=fig_dir)\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Constructing a Confusion Matrix\n",
    "\n",
    "The model breaks down in a way that generally intuitively makes sense to the expert eye with knowledge of the underlying signal formats.  The analog signals are confused for other analog signals, the PSK/QAM signals are confused for other PSK/QAMs with differing modulation orders, etc.  This behavior was examined more closely for targeted adversarial attacks in [Bair et al.] where it was found that the adversarial distance [Papernot et al.], or put more simply, the adversarial perturbation power required to \"transform\" 50% of the input examples from a source to target class, could be used in a spectral clustering algorithm to recover the relationships between the signal formats.\n",
    "\n",
    "##### Bair et al.\n",
    "\n",
    "Bair, S., Delvecchio, M., Flowers, B., Michaels, A. J., andHeadley, W. C. (2019). On the limitations of targeted adversarial evasion attacks against deep learning enabled modulation recognition. In ACM Workshop on Wireless Security and Machine Learning (WiseML 2019).\n",
    "\n",
    "#### Papernot et al.\n",
    "\n",
    "Papernot,  N.,  McDaniel,  P.,  Jha,  S.,  Fredrikson,  M.,Celik,  Z.  B.,  and  Swami,  A.  (2016).   The  limitations  of  deep  learningin  adversarial  settings.   InIEEE  European  Symposium  on  Security  andPrivacy (EuroS&P), pages 372–387. IEEE."
   ]
  },
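  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers unfamiliar with how such a matrix is assembled, the sketch below tallies a row-normalized confusion matrix from integer labels in plain Python.  The function name ``confusion_matrix`` and the toy labels are illustrative assumptions; the notebook itself relies on the library's ``_confusion_matrix``, whose exact normalization may differ.\n",
    "\n",
    "```python\n",
    "def confusion_matrix(predictions, labels, n_classes):\n",
    "    # Rows index the true class, columns index the predicted class\n",
    "    cm = [[0.0] * n_classes for _ in range(n_classes)]\n",
    "    for pred, true in zip(predictions, labels):\n",
    "        cm[true][pred] += 1.0\n",
    "    # Normalize each row so entries read as P(predicted | true)\n",
    "    for row in cm:\n",
    "        total = sum(row)\n",
    "        if total > 0:\n",
    "            for j in range(n_classes):\n",
    "                row[j] /= total\n",
    "    return cm\n",
    "\n",
    "\n",
    "# Toy data: class 1 is mistaken for class 2 half of the time\n",
    "labels = [0, 0, 1, 1, 2, 2]\n",
    "predictions = [0, 0, 1, 2, 2, 2]\n",
    "cm = confusion_matrix(predictions, labels, n_classes=3)\n",
    "for row in cm:\n",
    "    print(row)\n",
    "```\n",
    "\n",
    "Off-diagonal mass in a row shows which classes a given true class is mistaken for, which is how the analog-to-analog and PSK/QAM confusions described above become visible."
   ]
  },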
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spr = 10  # dB\n",
    "\n",
    "predictions = list()\n",
    "labels = list()\n",
    "for x, y in dl:\n",
    "    # Before crafting an attack on X, ensure that it is normalized to have a\n",
    "    # specified power -- RML2016.10a fluctuates a bit, especially on AM-SSB,\n",
    "    # which will cause the attack intensity to be incorrect.\n",
    "    x = _normalize(x)\n",
    "    adv_x = fgsm(x, y, spr=spr, input_size=input_samples, sps=1, net=model)\n",
    "\n",
    "    # Ensure that we've accurately represented the attack power\n",
    "    _sanity_check(desired_spr=spr, adv_x=adv_x, x=x)\n",
    "\n",
    "    _predictions = model.predict(adv_x)\n",
    "    predictions.extend(_predictions.detach().numpy())\n",
    "    labels.extend(y)\n",
    "\n",
    "cmn = _confusion_matrix(predictions=predictions, labels=labels, le=le)\n",
    "\n",
    "title = \"Confusion Matrix with a {spr} dB FGSM Attack\".format(spr=spr)\n",
    "fig = plot_confusion(cm=cmn, labels=le.labels, title=title)\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/fgsm_confusion_matrix.pdf\"\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
